<!DOCTYPE html>
<html>
<head>
  <script src="face-api.js"></script>
  <script src="commons.js"></script>
  <script src="js/randomCrop.js"></script>
  <script src="js/indexedDb.js"></script>
  <script src="FileSaver.js"></script>
</head>
<body>
  <div id="container"></div>

  <script>
    // expose the tfjs instance that face-api.js provides as an explicit global;
    // a bare `tf = ...` is an implicit global, which throws in strict mode / modules
    window.tf = faceapi.tf

    // network that is being trained
    window.net = new faceapi.FaceLandmark68LargeNet()

    // uri to weights file of last checkpoint
    const modelCheckpoint = '/tmp/initial_glorot.weights'
    // first epoch number of the training loop (only affects the epoch counter)
    const startEpoch = 0

    const learningRate = 0.001
    // adam optimizer arguments: learningRate, beta1, beta2, epsilon
    window.optimizer = tf.train.adam(learningRate, 0.9, 0.999, 1e-8)

    // NOTE(review): not read anywhere in this script; possibly consumed by one of the helper scripts - confirm
    window.saveEveryNthSample = 50000000

    // when true, train() augments every sample with randomCrop()
    window.withRandomCrop = false

    // number of samples per optimizer step
    window.batchSize = 12

    // loss accumulated over the batches of the current epoch
    window.lossValue = 0

    // when true, train() draws the predicted landmarks of every batch and does not shuffle the samples
    window.drawResults = true

    // console.log with a "[HH:MM:SS] " prefix; additional arguments are forwarded untouched
    const log = (str, ...args) => {
      const timestamp = new Date().toTimeString().slice(0, 8)
      console.log(`[${timestamp}] ${str || ''}`, ...args)
    }

    /**
     * Serializes the parameters of `net` and hands them to `saveAs`
     * (provided by FileSaver.js) wrapped in a Blob.
     *
     * @param net network exposing `serializeParams()`
     * @param filename name passed to `saveAs`, defaults to 'train_tmp'
     */
    function saveWeights(net, filename = 'train_tmp') {
      const serializedParams = net.serializeParams()
      const weightsBlob = new Blob([serializedParams])
      saveAs(weightsBlob, filename)
    }

    /**
     * Downloads a weights file and returns its content as a Float32Array.
     *
     * @param uri location of the weights file
     * @returns Float32Array view over the downloaded bytes
     * @throws Error if the server answers with a non-2xx status; without this
     *   check the body of an error page would be reinterpreted as weights
     */
    async function loadNetWeights(uri) {
      const res = await fetch(uri)
      if (!res.ok) {
        throw new Error(`loadNetWeights - failed to fetch weights from ${uri}: ${res.status} ${res.statusText}`)
      }
      return new Float32Array(await res.arrayBuffer())
    }

    /**
     * Requests the ids of the training samples from '/face_landmarks_train_ids'.
     *
     * @returns the parsed JSON body of the response
     * @throws Error if the server answers with a non-2xx status, instead of
     *   failing later with an unrelated JSON parse error
     */
    async function fetchTrainDataFilenames() {
      const res = await fetch('/face_landmarks_train_ids')
      if (!res.ok) {
        throw new Error(`fetchTrainDataFilenames - request failed: ${res.status} ${res.statusText}`)
      }
      return res.json()
    }

    /**
     * Prepares a training run: stores the ids of the samples to train on in
     * window.ptsFiles and initializes window.net from the checkpoint weights.
     */
    async function load() {
      // only a single batch worth of samples, starting at index 1000, is kept
      const firstSampleIdx = 1000

      window.ptsFiles = await fetchTrainDataFilenames()
      const { ptsFiles, batchSize } = window
      window.ptsFiles = ptsFiles.slice(firstSampleIdx, firstSampleIdx + batchSize)

      const checkpointWeights = await loadNetWeights(modelCheckpoint)
      await window.net.load(checkpointWeights)
      // presumably converts the loaded params into trainable variables - confirm against face-api.js
      window.net.variable()
    }

    /**
     * Cuts a random rectangle out of `img` whose edges are placed randomly
     * between the image border and the bounding box of the landmarks, and
     * shifts the landmarks into the coordinate system of the cropped image.
     *
     * @param img image or canvas to crop (must expose width / height)
     * @param pts landmark positions ({ x, y }) in pixel coordinates of `img`
     * @returns { croppedImage, shiftedPts } - the cropped canvas and the
     *   landmarks relative to its top left corner
     */
    function randomCrop(img, pts) {
      const xCoords = pts.map(pt => pt.x)
      const yCoords = pts.map(pt => pt.y)

      // bounding box of the landmarks, clamped to start inside the image
      const minX = Math.min(img.width, ...xCoords)
      const minY = Math.min(img.height, ...yCoords)
      const maxX = Math.max(0, ...xCoords)
      const maxY = Math.max(0, ...yCoords)

      // Snap the crop rectangle to whole pixels (outwards, so no landmark is cut off).
      // The canvas size and getImageData() work on integers, a fractional origin
      // would shift the landmarks by a different amount than the pixels.
      const x0 = Math.floor(Math.random() * minX)
      const y0 = Math.floor(Math.random() * minY)

      const x1 = Math.ceil((Math.random() * Math.abs(img.width - maxX)) + maxX)
      const y1 = Math.ceil((Math.random() * Math.abs(img.height - maxY)) + maxY)

      const targetCanvas = faceapi.createCanvas({ width: x1 - x0, height: y1 - y0 })
      const inputCanvas = img instanceof HTMLCanvasElement ? img : faceapi.createCanvasFromMedia(img)
      const region = faceapi.getContext2dOrThrow(inputCanvas)
        .getImageData(x0, y0, targetCanvas.width, targetCanvas.height)
      faceapi.getContext2dOrThrow(targetCanvas).putImageData(region, 0, 0)

      return {
        croppedImage: targetCanvas,
        shiftedPts: pts.map(pt => new faceapi.Point(pt.x - x0, pt.y - y0))
      }
    }

    /**
     * Loads the training data ids and the checkpoint, then trains window.net
     * epoch after epoch until the page is closed (the epoch loop never ends).
     *
     * Per batch: fetches landmarks and images, optionally random-crops them,
     * pads every image to a square of 112px, converts the landmarks to
     * coordinates relative to that square and runs one optimizer step on the
     * mean squared error between prediction and ground truth.
     *
     * @throws Error if no training samples are available after load()
     */
    async function train() {
      await load()

      // Without samples the epoch loop below would never reach an `await` and
      // would spin synchronously forever, freezing the page.
      if (!window.ptsFiles || window.ptsFiles.length === 0) {
        throw new Error('train - no training samples available after load()')
      }

      // edge length in px of the square network input
      const inputSize = 112

      for (let epoch = startEpoch; epoch < Infinity; epoch++) {

        if (epoch !== startEpoch) {
          // checkpointing is disabled, uncomment to save weights and loss after every epoch
          //saveWeights(window.net, `landmarks_epoch${epoch - 1}.weights`)
          //saveAs(new Blob([JSON.stringify({ loss: window.lossValue, avgLoss: window.lossValue / 20000 })]), `landmarks_epoch${epoch - 1}.json`)
          window.lossValue = 0
        }

        // no shuffling while drawing, presumably to keep every sample at the same position in the container
        const shuffledInputs = window.drawResults ? window.ptsFiles : faceapi.shuffleArray(window.ptsFiles)

        for (let dataIdx = 0; dataIdx < shuffledInputs.length; dataIdx += window.batchSize) {

          // the last batch of an epoch may contain fewer than window.batchSize samples
          const ids = shuffledInputs.slice(dataIdx, dataIdx + window.batchSize)

          // fetch ground truth landmark points and images of the batch
          console.time('getPts')
          const allPts = (await getPts(ids)).map(res => res.pts)
          console.timeEnd('getPts')
          console.time('getJpgs')
          const allJpgs = await Promise.all(
            (await getJpgs(ids)).map(res => faceapi.bufferToImage(res.blob))
          )
          console.timeEnd('getJpgs')

          // squared network inputs and the matching relative ground truth landmarks
          const bImages = []
          const bPts = []

          for (let batchIdx = 0; batchIdx < ids.length; batchIdx += 1) {

            const pts = allPts[batchIdx]
            const img = allJpgs[batchIdx]

            const { croppedImage, shiftedPts } = window.withRandomCrop
              ? randomCrop(img, pts)
              : { croppedImage: img, shiftedPts: pts }

            const squaredImg = faceapi.imageToSquare(croppedImage, inputSize, true)

            const reshapedDimensions = faceapi.computeReshapedDimensions(croppedImage, inputSize)

            // relative padding added on both sides of the shorter dimension
            const pad = Math.abs(reshapedDimensions.width - reshapedDimensions.height) / (2 * inputSize)
            const padX = reshapedDimensions.width < reshapedDimensions.height ? pad : 0
            const padY = reshapedDimensions.height < reshapedDimensions.width ? pad : 0

            // landmarks relative to the padded square, i.e. in the range the net is trained to predict
            const groundTruthLandmarks = shiftedPts.map(pt => new faceapi.Point(
              padX + (pt.x / croppedImage.width) * (reshapedDimensions.width / inputSize),
              padY + (pt.y / croppedImage.height) * (reshapedDimensions.height / inputSize)
            ))

            bImages.push(squaredImg)
            bPts.push(groundTruthLandmarks)
          }

          // [x0, y0, x1, y1, ...] of all samples one after another
          const flatGroundTruth = []
          bPts.forEach(samplePts => samplePts.forEach(pt => flatGroundTruth.push(pt.x, pt.y)))

          const netInput = await faceapi.toNetInput(bImages)

          const loss = window.optimizer.minimize(() => {
            console.time('runNet')
            const out = window.net.runNet(netInput)
            console.timeEnd('runNet')

            if (window.drawResults) {
              // debug: draw the predicted landmarks of every sample onto its squared input image
              // (to inspect the ground truth instead, take the coordinates from bPts[sampleIdx])
              const container = document.getElementById('container')

              tf.unstack(out).forEach((landmarkTensor, sampleIdx) => {
                const squaredImg = bImages[sampleIdx]

                // the output alternates x and y coordinates
                const coords = Array.from(landmarkTensor.dataSync())
                const xCoords = coords.filter((_, j) => faceapi.isEven(j))
                const yCoords = coords.filter((_, j) => !faceapi.isEven(j))

                const marks = new faceapi.FaceLandmarks68(
                  Array(68).fill(0).map((_, j) => new faceapi.Point(xCoords[j], yCoords[j])),
                  squaredImg
                )

                const canvas = squaredImg instanceof HTMLCanvasElement ? squaredImg : faceapi.createCanvasFromMedia(squaredImg)
                faceapi.drawLandmarks(canvas, marks, { drawLines: true })

                // reuse the slot of the previous drawing of this batch position
                const oldChild = container.children[sampleIdx]
                if (oldChild) {
                  container.insertBefore(canvas, oldChild)
                  container.removeChild(oldChild)
                } else {
                  container.appendChild(canvas)
                }
              })
            }

            // rows follow the actual batch length, window.batchSize would not fit a partial batch
            const landmarksTensor = tf.tensor2d(flatGroundTruth, [ids.length, 136])

            //return tf.sum(tf.abs(tf.sub(landmarksTensor, out)))
            return tf.losses.meanSquaredError(
              landmarksTensor,
              out,
              tf.Reduction.MEAN
            )
          }, true)

          // start next iteration without waiting for loss data
          loss.data()
            .then(data => {
              const lossValue = data[0]
              window.lossValue += lossValue
              log(`epoch ${epoch}, dataIdx ${dataIdx} - loss: ${lossValue}`)
            })
            .catch(err => console.error('train - failed to read loss value', err))
            // release the loss tensor no matter whether reading it succeeded
            .then(() => loss.dispose())

          if (window.iterDelay) {
            await delay(window.iterDelay)
          } else {
            await tf.nextFrame()
          }
        }
      }
    }

  </script>
</body>
</html>